package org.bdware.analysis;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.objectweb.asm.*;
import org.objectweb.asm.tree.*;

import java.util.*;

/**
 * Control-flow graph of a single ASM {@link MethodNode}, built eagerly by the constructor.
 *
 * <p>Construction runs two passes over the method's instructions:
 * <ol>
 * <li><b>Pass 1</b> ({@code InsnPass1Visitor}) splits the instruction list into basic blocks. A
 * new block is started
 * <ul>
 * <li>at every label, if the current block already holds instructions (a fall-through edge links
 * the two blocks);</li>
 * <li>after every zero-operand instruction whose {@code OpInfo} reports {@code canReturn()} (an
 * edge to the synthetic end block is added);</li>
 * <li>after every jump (a fall-through edge is added unless the jump is a {@code goto});</li>
 * <li>after every table/lookup switch;</li>
 * <li>after every method invocation and invokedynamic (edges to both the end block and the next
 * block are added).</li>
 * </ul>
 * Every label is also assigned an increasing position in {@link #getLabelOrder()}.</li>
 * <li><b>Pass 2</b> ({@code InsnPass2Visitor}) numbers the non-empty blocks and adds the edges
 * whose targets are labels: jump targets, switch targets, and edges from blocks holding
 * invocations or throw instructions to the handlers of the try-catch ranges covering them. It
 * also records line numbers; blocks without one inherit the closest preceding block's.</li>
 * </ol>
 *
 * <p>NOTE: blocks are created through the abstract {@link #getBasicBlock(int)} factory, which is
 * called from this class's constructor (id {@code -1} for the end block, {@code 0} for the entry
 * block, and the current block count otherwise), i.e. before any subclass constructor body runs.
 */
public abstract class CFGraph {
    private static final Logger LOGGER = LogManager.getLogger(CFGraph.class);
    /** Predecessor sets keyed by block; entries are created lazily by {@link #getPreBlocks}. */
    private final Map<BasicBlock, Set<BasicBlock>> preBlock;
    /** Block containing each label, including the end block's synthetic label. */
    private final Map<Label, BasicBlock> labelToBB;
    /** Position of each label in instruction order; used for try-catch range checks. */
    private final Map<Label, Integer> labelOrder = new HashMap<>();
    /** Blocks in instruction order; after construction only non-empty ones, end block last. */
    protected List<BasicBlock> basicBlocks;
    /** Successor sets keyed by block; entries are created lazily by {@link #getSucBlocks}. */
    protected Map<BasicBlock, Set<BasicBlock>> sucBlock;
    /** The method this graph was built from. */
    MethodNode methodNode;

    /**
     * Builds the control-flow graph of {@code mn}.
     *
     * @param mn the method to analyse; its {@code instructions} and {@code tryCatchBlocks} are read
     */
    public CFGraph(MethodNode mn) {
        basicBlocks = new ArrayList<>();
        preBlock = new HashMap<>();
        sucBlock = new HashMap<>();
        labelToBB = new HashMap<>();
        methodNode = mn;
        buildBasicBlock(mn.instructions, mn.tryCatchBlocks);
    }

    /** Records the directed edge {@code pre -> suc} in both the predecessor and successor maps. */
    private void addEdge(BasicBlock pre, BasicBlock suc) {
        getPreBlocks(suc).add(pre);
        getSucBlocks(pre).add(suc);
    }

    /**
     * Returns the live, mutable successor set of {@code pre}, registering an empty set first if
     * none exists yet.
     *
     * @param pre the block whose successors are requested
     * @return the successor set stored in {@link #sucBlock}
     */
    public Set<BasicBlock> getSucBlocks(BasicBlock pre) {
        Set<BasicBlock> ret = sucBlock.get(pre);
        if (ret == null) {
            ret = new HashSet<>();
            sucBlock.put(pre, ret);
        }
        return ret;
    }

    /**
     * Returns the live, mutable predecessor set of {@code suc}, registering an empty set first if
     * none exists yet.
     *
     * @param suc the block whose predecessors are requested
     * @return the predecessor set stored for {@code suc}
     */
    public Set<BasicBlock> getPreBlocks(BasicBlock suc) {
        Set<BasicBlock> ret = preBlock.get(suc);
        if (ret == null) {
            ret = new HashSet<>();
            preBlock.put(suc, ret);
        }
        return ret;
    }

    /**
     * Supplies a block for the given id. This class uses it as a factory during construction:
     * {@code -1} for the synthetic end block, {@code 0} for the entry block and the current block
     * count for every further block. Each block's {@code blockID} is later set to its index among
     * the non-empty blocks.
     *
     * @param id the provisional id described above
     * @return the block to use
     */
    public abstract BasicBlock getBasicBlock(int id);

    /** Runs both construction passes and then fills in missing line numbers. */
    private void buildBasicBlock(InsnList instructions, List<TryCatchBlockNode> tryCatchBlocks) {
        splitIntoBlocks(instructions);
        removeEmptyBlocks();
        addLabelEdges(tryCatchBlocks);
        propagateLineNumbers();
    }

    /** Pass 1: splits {@code instructions} into blocks and adds fall-through and exit edges. */
    private void splitIntoBlocks(InsnList instructions) {
        InsnPass1Visitor visitor = new InsnPass1Visitor(Opcodes.ASM4);
        visitor.currBlock = getBasicBlock(0);
        basicBlocks.add(visitor.currBlock);
        for (int i = 0; i < instructions.size(); i++) {
            AbstractInsnNode insn = instructions.get(i);
            visitor.setCurrInsn(insn);
            insn.accept(visitor);
        }
        visitor.visitEnd();
    }

    /**
     * Drops blocks that received no instructions, such as a block opened after a final return.
     * Edges already recorded for dropped blocks are left in the edge maps.
     */
    private void removeEmptyBlocks() {
        List<BasicBlock> nonEmpty = new ArrayList<>();
        for (BasicBlock bb : basicBlocks) {
            if (bb.list.size() > 0) {
                nonEmpty.add(bb);
            }
        }
        basicBlocks = nonEmpty;
    }

    /** Pass 2: numbers the blocks and adds jump, switch and exception-handler edges. */
    private void addLabelEdges(List<TryCatchBlockNode> tryCatchBlocks) {
        InsnPass2Visitor pass2 = new InsnPass2Visitor(Opcodes.ASM4);
        pass2.tryCatchBlocks = tryCatchBlocks;
        for (int i = 0; i < basicBlocks.size(); i++) {
            BasicBlock bb = basicBlocks.get(i);
            pass2.blockid = i;
            bb.blockID = i;
            // Pass 1 only ever places a label at the start of a block, so this label is the most
            // recent one preceding every instruction of bb; label-less blocks keep the previous one.
            if (bb.size() > 0 && bb.list.get(0) instanceof LabelNode) {
                pass2.preLabel = ((LabelNode) bb.list.get(0)).getLabel();
            }
            for (AbstractInsnNode node : bb.list) {
                node.accept(pass2);
            }
        }
    }

    /** Gives every block whose line number is -1 the line number of the closest preceding one. */
    private void propagateLineNumbers() {
        int preLine = -1;
        for (BasicBlock block : basicBlocks) {
            if (block.lineNum != -1) {
                preLine = block.lineNum;
            } else {
                block.lineNum = preLine;
            }
        }
    }

    /**
     * Dumps the graph for debugging. The method header and per-block successor lists go to the
     * logger; each block's instructions are handed to an {@code InsnPrinter} constructed with
     * {@code System.out}.
     */
    public void printSelf() {
        InsnPrinter printer = new InsnPrinter(Opcodes.ASM4, System.out);
        printer.setLabelOrder(getLabelOrder());
        LOGGER.info("isStatic: {}\tMethod:{}  {}",
                (methodNode.access & Opcodes.ACC_STATIC) > 0, methodNode.name, methodNode.desc);
        LOGGER.info("{}  {}", methodNode.maxLocals, methodNode.maxStack);
        StringBuilder log = new StringBuilder();
        for (BasicBlock bb : basicBlocks) {
            Set<BasicBlock> sucs = getSucBlocks(bb);
            log.append("B").append(bb.blockID);
            if (!sucs.isEmpty()) {
                log.append(" -->");
            }
            for (BasicBlock suc : sucs) {
                log.append(" B").append(suc.blockID);
            }
            for (AbstractInsnNode an : bb.list) {
                an.accept(printer);
            }
            log.append("\n");
        }
        // Drop the trailing newline; guarded so an empty block list cannot throw.
        if (log.length() > 0) {
            log.setLength(log.length() - 1);
        }
        LOGGER.info(log.toString());
    }

    /**
     * Returns the position of every label in instruction order. The end block's synthetic label
     * comes last.
     *
     * @return the live label-order map
     */
    public Map<Label, Integer> getLabelOrder() {
        return labelOrder;
    }

    /**
     * Returns the block at index {@code i}, which equals its {@code blockID}.
     *
     * @param i block index
     * @return the block
     */
    public BasicBlock getBasicBlockAt(int i) {
        return basicBlocks.get(i);
    }

    /**
     * Returns the number of blocks, including the end block.
     *
     * @return the block count
     */
    public int getBasicBlockSize() {
        return basicBlocks.size();
    }

    /**
     * Returns the method this graph was built from.
     *
     * @return the analysed method
     */
    public MethodNode getMethodNode() {
        return methodNode;
    }

    /**
     * Returns the block containing label {@code l}, or {@code null} if the label is unknown.
     *
     * @param l a label of the analysed method
     * @return the containing block, or {@code null}
     */
    public BasicBlock getBasicBlockByLabel(Label l) {
        return labelToBB.get(l);
    }

    /**
     * Pass 1. The driver loop sets {@code currInsn} before each {@code accept}. Every visited node
     * is appended to the current block. A new block is opened wherever control may leave the
     * current one.
     */
    private class InsnPass1Visitor extends MethodVisitor {
        /** Next position to hand out in the label order. */
        int count = 0;
        /** The node currently being visited. */
        AbstractInsnNode currInsn;
        /** Block that instructions are currently appended to. */
        BasicBlock currBlock;
        /** Synthetic exit block; appended to {@code basicBlocks} by {@link #visitEnd()}. */
        BasicBlock endBlock = getEndBlock();

        public InsnPass1Visitor(int api) {
            super(api);
        }

        /**
         * Opens a fresh block (id = current block count) and makes it current. When
         * {@code fallThrough} is set, also adds an edge from the block being closed to the new one.
         */
        private void startNewBlock(boolean fallThrough) {
            BasicBlock previous = currBlock;
            currBlock = getBasicBlock(basicBlocks.size());
            basicBlocks.add(currBlock);
            if (fallThrough) {
                addEdge(previous, currBlock);
            }
        }

        // ---------------------------------------------------------------------
        // Nodes that never end a block: they are simply appended.
        // ---------------------------------------------------------------------

        @Override
        public void visitFrame(int type, int nLocal, Object[] local, int nStack, Object[] stack) {
            currBlock.add(currInsn);
        }

        @Override
        public void visitIntInsn(int opcode, int operand) {
            currBlock.add(currInsn);
        }

        @Override
        public void visitVarInsn(int opcode, int var) {
            currBlock.add(currInsn);
        }

        @Override
        public void visitTypeInsn(int opcode, String type) {
            currBlock.add(currInsn);
        }

        @Override
        public void visitFieldInsn(int opcode, String owner, String name, String desc) {
            currBlock.add(currInsn);
        }

        @Override
        public void visitLdcInsn(Object cst) {
            currBlock.add(currInsn);
        }

        @Override
        public void visitIincInsn(int var, int increment) {
            currBlock.add(currInsn);
        }

        @Override
        public void visitMultiANewArrayInsn(String desc, int dims) {
            currBlock.add(currInsn);
        }

        @Override
        public void visitLocalVariable(String name, String desc, String signature, Label start, Label end, int index) {
            currBlock.add(currInsn);
        }

        @Override
        public void visitLineNumber(int line, Label start) {
            currBlock.add(currInsn);
        }

        // ---------------------------------------------------------------------
        // Nodes that may end or start a block.
        // ---------------------------------------------------------------------

        /**
         * Zero-operand instruction. If {@code OpInfo} reports it can return, the block ends with
         * an edge to the exit block and no fall-through edge.
         */
        @Override
        public void visitInsn(int opcode) {
            currBlock.add(currInsn);
            if (OpInfo.ops[opcode].canReturn()) {
                addEdge(currBlock, endBlock);
                startNewBlock(false);
            }
        }

        /**
         * Five-argument overload used by {@code MethodInsnNode.accept} in ASM 5+. It is
         * intentionally not annotated with {@code @Override}, so the class still compiles against
         * ASM versions without it. Overriding it bypasses the base class's api-ASM4 check, which
         * rejects INVOKESTATIC/INVOKESPECIAL on interface methods. Every call is routed to the
         * four-argument handler.
         */
        public void visitMethodInsn(int opcode, String owner, String name, String desc, boolean itf) {
            visitMethodInsn(opcode, owner, name, desc);
        }

        /**
         * Method invocation. The block ends with edges to the exit block and to the next block.
         * Edges to exception handlers are added in pass 2.
         */
        public void visitMethodInsn(int opcode, String owner, String name, String desc) {
            currBlock.add(currInsn);
            addEdge(currBlock, endBlock);
            startNewBlock(true);
        }

        /** Treated like a method invocation; exception-handler edges are added in pass 2. */
        @Override
        public void visitInvokeDynamicInsn(String name, String desc, Handle bsm, Object... bsmArgs) {
            currBlock.add(currInsn);
            addEdge(currBlock, endBlock);
            startNewBlock(true);
        }

        /**
         * Jump. Every jump ends the block. Only a {@code goto} (detected through its
         * {@code OpInfo} name) has no fall-through edge. Jump targets are linked in pass 2.
         */
        @Override
        public void visitJumpInsn(int opcode, Label label) {
            currBlock.add(currInsn);
            startNewBlock(!OpInfo.ops[opcode].toString().contains("goto"));
        }

        /**
         * Label. Records its order and always makes it the first node of a block. If the current
         * block is non-empty, a new block is opened with a fall-through edge.
         */
        @Override
        public void visitLabel(Label label) {
            getLabelOrder().put(label, count++);
            if (currBlock.size() > 0) {
                startNewBlock(true);
            }
            currBlock.add(currInsn);
            labelToBB.put(label, currBlock);
        }

        /** Switch: ends the block without fall-through; targets are linked in pass 2. */
        @Override
        public void visitTableSwitchInsn(int min, int max, Label dflt, Label... labels) {
            currBlock.add(currInsn);
            startNewBlock(false);
        }

        /** Switch: ends the block without fall-through; targets are linked in pass 2. */
        @Override
        public void visitLookupSwitchInsn(Label dflt, int[] keys, Label[] labels) {
            currBlock.add(currInsn);
            startNewBlock(false);
        }

        public void setCurrInsn(AbstractInsnNode insn) {
            currInsn = insn;
        }

        /** Creates the exit block (id -1) holding a fresh label, and registers that label. */
        private BasicBlock getEndBlock() {
            BasicBlock ret = getBasicBlock(-1);
            Label l = new Label();
            ret.add(new LabelNode(l));
            labelToBB.put(l, ret);
            return ret;
        }

        /** Appends the exit block last and gives its label the final position in the order. */
        @Override
        public void visitEnd() {
            basicBlocks.add(endBlock);
            getLabelOrder().put(((LabelNode) (endBlock.list.get(0))).getLabel(), count++);
        }
    }

    /**
     * Pass 2. Visits the numbered blocks and adds edges whose targets are labels, plus edges to
     * exception handlers.
     */
    private class InsnPass2Visitor extends MethodVisitor {
        /** Index into {@code basicBlocks} of the block being visited. */
        int blockid;
        /** Most recent label at or before the current instruction; null before the first label. */
        Label preLabel;
        /** Try-catch ranges of the method; treated as empty when null. */
        List<TryCatchBlockNode> tryCatchBlocks;

        public InsnPass2Visitor(int api) {
            super(api);
        }

        /**
         * Adds an edge from the current block to the handler of every try-catch range covering
         * the current instruction. Because labels only start blocks, an instruction lies in
         * {@code [start, end)} iff its nearest preceding label ({@link #preLabel}) does.
         */
        public void addTryCatchNodes() {
            // With no label seen yet, the instruction precedes every try range start.
            if (tryCatchBlocks == null || preLabel == null) {
                return;
            }
            // The edge must leave the block that holds the throwing instruction, which is not
            // necessarily the block that starts with preLabel.
            BasicBlock currBlock = basicBlocks.get(blockid);
            for (TryCatchBlockNode node : tryCatchBlocks) {
                Label start = node.start.getLabel();
                Label end = node.end.getLabel();
                Label handler = node.handler.getLabel();
                if (inLabel(start, end, preLabel)) {
                    addEdge(currBlock, labelToBB.get(handler));
                }
            }
        }

        @Override
        public void visitLineNumber(int line, Label start) {
            BasicBlock currBlock = basicBlocks.get(blockid);
            currBlock.setLineNum(line);
        }

        @Override
        public void visitJumpInsn(int opcode, Label label) {
            BasicBlock currBlock = basicBlocks.get(blockid);
            BasicBlock targetBlock = labelToBB.get(label);
            addEdge(currBlock, targetBlock);
        }

        @Override
        public void visitTableSwitchInsn(int min, int max, Label dflt, Label... labels) {
            BasicBlock currBlock = basicBlocks.get(blockid);
            addEdge(currBlock, labelToBB.get(dflt));
            for (Label label : labels) {
                addEdge(currBlock, labelToBB.get(label));
            }
        }

        @Override
        public void visitLookupSwitchInsn(Label dflt, int[] keys, Label[] labels) {
            BasicBlock currBlock = basicBlocks.get(blockid);
            addEdge(currBlock, labelToBB.get(dflt));
            for (Label label : labels) {
                addEdge(currBlock, labelToBB.get(label));
            }
        }

        /**
         * Constants add no edges. This is an explicit no-op, so the base {@code MethodVisitor}'s
         * api-version checks on constant kinds are never reached with api ASM4.
         */
        @Override
        public void visitLdcInsn(Object cst) {
        }

        @Override
        public void visitInvokeDynamicInsn(String name, String desc, Handle bsm, Object... bsmArgs) {
            addTryCatchNodes();
        }

        /** Returns whether {@code query} lies in {@code [start, end)} of the label order. */
        private boolean inLabel(Label start, Label end, Label query) {
            int s = getLabelOrder().get(start);
            int e = getLabelOrder().get(end);
            int m = getLabelOrder().get(query);
            return m >= s && m < e;
        }

        /**
         * Five-argument overload used by {@code MethodInsnNode.accept} in ASM 5+. It is not
         * annotated with {@code @Override} for ASM 4 compatibility. It bypasses the base class's
         * api-ASM4 rejection of interface-method invocations.
         */
        public void visitMethodInsn(int opcode, String owner, String name, String desc, boolean itf) {
            visitMethodInsn(opcode, owner, name, desc);
        }

        public void visitMethodInsn(int opcode, String owner, String name, String desc) {
            addTryCatchNodes();
        }

        /** Instructions whose {@code OpInfo} name contains "throw" get exception-handler edges. */
        @Override
        public void visitInsn(int opcode) {
            if (OpInfo.ops[opcode].toString().contains("throw")) {
                addTryCatchNodes();
            }
        }
    }
}
